import tensorflow as tf
"""
反池化原理
反池化是池化的逆操作，无法通过池化的结果还原出全部的原始数据。
池化的过程只是保留主要的信息，去除部分信息。想从池化后的这些主要信息恢复
全部信息，则会存在信息缺失，这时只能通过补位来实现最大程度的信息完整。

池化有最大池化与平均池化，则反池化也会对应这两个操作：
1) 平均池化首先还原成原来的大小，然后将池化结果中的每个值都填入其对应于
   原始数据区域中的相应位置
2）最大池化要求在池化过程中记录最大激活的坐标位置，然后在反池化时，
  只把池化过程中最大值所在位置坐标的值激活，其他的值置为0
  
  TensorFlow中并没有反池化的操作函数。也不支持输出最大激活值的位置，
  但是有个池化的反向传播函数 tf.nn.max_pool_with_argmax,该函数可以
  输出位置
"""


# 重新定义最大池化函数
def max_pool_with_argmax(net, stride, padding='SAME'):
    """
    函数首先调用 max_pool_with_argmax 函数获得每个最大值的位置 mask，
    再将反向传播的 mask 梯度计算停止，接着调用 tf.nn.max_pool 函数
    计算最大池化操作，然后返回结果
    """
    _, _mask_ = tf.nn.max_pool_with_argmax(net, ksize=[1, stride, stride, 1],
                                           strides=[1, stride, stride, 1], padding=padding)
    _mask_ = tf.stop_gradient(_mask_)
    net = tf.nn.max_pool(net, ksize=[1, stride, stride, 1], strides=[1, stride, stride, 1], padding=padding)
    return net, _mask_


def unpool(net, mask, stride):
    """
    定义最大反池化函数
    """
    ksize = [1, stride, stride, 1]
    input_shape = net.get_shape().as_list()
    # 计算 new shape
    output_shape = (input_shape[0], input_shape[1]*ksize[1], input_shape[2]*ksize[2], input_shape[3])
    # 计算索引
    one_like_mask = tf.ones_like(mask)
    batch_range = tf.reshape(tf.range(output_shape[0], dtype=tf.int64), shape=[input_shape[0], 1, 1, 1])
    b = one_like_mask*batch_range
    y = mask // (output_shape[2]*output_shape[3])
    x = mask % (output_shape[2]*output_shape[3]) // output_shape[3]
    feature_range = tf.range(output_shape[3], dtype=tf.int64)
    f = one_like_mask * feature_range
    # 转置索引
    updates_size = tf.size(net)
    indices = tf.transpose(tf.reshape(tf.stack([b, y, x, f]), [4, updates_size]))
    values = tf.reshape(net, [updates_size])
    ret = tf.scatter_nd(indices, values, output_shape)
    return ret


# 测试
img = tf.constant([
    [[0., 4.], [0., 4.], [0., 4.], [0., 4.]],
    [[1., 5.], [1., 5.], [1., 5.], [1., 5.]],
    [[2., 6.], [2., 6.], [2., 6.], [2., 6.]],
    [[3., 7.], [3., 7.], [3., 7.], [3., 7.]]
])

img = tf.reshape(img, [1, 4, 4, 2])
pooling = tf.nn.max_pool(img, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME')
encode, mask = max_pool_with_argmax(img, 2)

img2 = unpool(encode, mask, 2)

with tf.Session() as sess:
    print('image:\n', sess.run(img))
    print('pooling:\n', sess.run(pooling))
    result, mask0 = sess.run([encode, mask])
    print('result:\n', result)
    print('mask0:\n', mask0)

    print('img2:\n', sess.run(img2))


# pooling:
#  [[
#    [[1. 5.] [1. 5.]]
#    [[3. 7.] [3. 7.]]
#  ]]
# result:
#  [[
#    [[1. 5.] [1. 5.]]
#    [[3. 7.] [3. 7.]]
#  ]]
# mask0:
#  [[
#    [[ 8  9] [12 13]]
#    [[24 25] [28 29]]
#  ]]
# img2:
#  [[
#    [[0. 0.] [0. 0.] [0. 0.] [0. 0.]]
#    [[1. 5.] [0. 0.] [1. 5.] [0. 0.]]
#    [[0. 0.] [0. 0.] [0. 0.] [0. 0.]]
#    [[3. 7.] [0. 0.] [3. 7.] [0. 0.]]
#  ]]
